#!/usr/bin/env python3
from __future__ import annotations
import os, typing as t, sys
from pathlib import Path

# Repository root: the parent of the directory containing this script.
_ROOT = Path(__file__).parent.parent

# Put the in-tree packages at the very front of the import path (core first, then the
# python package) so the imports below resolve against this checkout, not an installed copy.
sys.path.insert(0, os.fspath(_ROOT / 'openllm-core' / 'src'))
sys.path.insert(1, os.fspath(_ROOT / 'openllm-python' / 'src'))
from openllm_core._typing_compat import LiteralBackend
from openllm.models import auto
from openllm import CONFIG_MAPPING

if t.TYPE_CHECKING: from collections import OrderedDict

def _normalised_requirements(config: t.Any) -> list[str] | None:
  """Return ``config.__openllm_requirements__`` with every '-' replaced by '_', or None when it is empty/falsy.

  NOTE(review): the '-' -> '_' rewrite presumably turns distribution names into the
  importable names checked by ``require_backends`` — confirm.
  """
  requirements = config.__openllm_requirements__
  if not requirements: return None
  return [req.replace('-', '_') for req in requirements]

# Extra per-model requirements keyed like CONFIG_MAPPING; None means "no extras".
config_requirements = {name: _normalised_requirements(cfg) for name, cfg in CONFIG_MAPPING.items()}

# Backends that get a generated dummy file. NOTE(review): this drops the last two
# entries of LiteralBackend — confirm those backends need no dummy objects. zip()
# pairs positionally, so both tuples below must follow LiteralBackend's order.
_GENERATED_BACKENDS = LiteralBackend.__args__[:-2]
_dependencies: dict[LiteralBackend, str] = dict(zip(_GENERATED_BACKENDS, ('torch', 'tensorflow', 'flax', 'vllm')))
_auto: dict[str, str] = dict(zip(_GENERATED_BACKENDS, ('AutoLLM', 'AutoTFLLM', 'AutoFlaxLLM', 'AutoVLLM')))

def get_target_dummy_file(backend: LiteralBackend) -> Path:
  """Return the path of the generated ``dummy_<backend>_objects.py`` inside ``openllm/utils``."""
  file_name = f'dummy_{backend}_objects.py'
  return _ROOT.joinpath('openllm-python', 'src', 'openllm', 'utils', file_name)

def mapping_names(backend: LiteralBackend):
  """Return the name of the model-mapping attribute for ``backend``.

  PyTorch (``'pt'``) uses the unprefixed ``MODEL_MAPPING_NAMES``. Every other backend
  has its upper-cased name spliced in, e.g. ``'vllm'`` -> ``'MODEL_VLLM_MAPPING_NAMES'``.
  """
  if backend != 'pt':
    return f'MODEL_{backend.upper()}_MAPPING_NAMES'
  return 'MODEL_MAPPING_NAMES'

def get_mapping(backend: LiteralBackend) -> OrderedDict[t.Any, t.Any]:
  """Fetch the model mapping for ``backend`` from the ``openllm.models.auto`` module.

  Callers in this script use the keys as model names and the values as the
  class names to generate.
  """
  attribute = mapping_names(backend)
  return getattr(auto, attribute)

def make_class_stub(model_name: str, backend: LiteralBackend, indentation: int = 2, auto: bool = False) -> list[str]:
  """Render the source lines of one dummy class for ``backend``.

  Args:
    model_name: Key into the backend's model mapping and into ``config_requirements``.
      ``'__default__'`` skips the per-model requirement lookup (used for the auto class).
    backend: Backend whose primary dependency is always listed first.
    indentation: Number of spaces used to indent the class body.
    auto: When true, emit the backend's auto class name from ``_auto`` instead of the
      model-specific class. Note: this parameter shadows the module-level ``auto``
      import inside this function. ``get_mapping`` still resolves the module.

  Returns:
    Three lines: the class header, a ``_backends`` attribute, and an ``__init__``
    that calls ``_require_backends`` with the same quoted dependency list.
  """
  primary = _dependencies[backend]
  extras = config_requirements[model_name] if model_name != '__default__' and config_requirements[model_name] else []
  quoted = ','.join(f'"{dep}"' for dep in (primary, *extras))
  if auto:
    class_name = _auto[backend]
  else:
    class_name = get_mapping(backend)[model_name]
  pad = ' ' * indentation
  return [
      f'class {class_name}(metaclass=_DummyMetaclass):',
      f'{pad}_backends=[{quoted}]',
      f'{pad}def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,[{quoted}])',
  ]

def write_stub(backend: LiteralBackend, _path: str) -> list[str]:
  """Build the full contents of ``dummy_<backend>_objects.py`` as a list of lines.

  Args:
    backend: Backend whose model mapping drives the generated classes.
    _path: Script path embedded in the "do not edit" header comments.

  Returns:
    Lines meant to be newline-joined. The final ``__all__`` line carries its own
    trailing newline, so the written file ends with one.
  """
  mapping = get_mapping(backend)
  names_attr = mapping_names(backend)
  lines = [
      f'# This file is generated by {_path}. DO NOT EDIT MANUALLY!',
      f'# To update this, run ./{_path}',
      'from __future__ import annotations',
      'import typing as _t',
      'from openllm_core.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends',
  ]
  # One dummy class per model, in mapping order.
  for model_name in mapping:
    lines.extend(make_class_stub(model_name, backend))
  # The backend's auto class; '__default__' means only the backend dependency is required.
  lines.extend(make_class_stub('__default__', backend, auto=True))
  # Placeholder mapping, then the export list: mapping name, auto class, every model class.
  exported = ','.join(f'"{class_name}"' for class_name in mapping.values())
  lines.append(f'{names_attr}:_t.Any=None')
  lines.append(f'__all__:list[str]=["{names_attr}","{_auto[backend]}",{exported}]\n')
  return lines

def main() -> int:
  """Regenerate ``dummy_<backend>_objects.py`` for every backend in ``_dependencies``.

  Returns:
    Process exit status: always 0. Failures propagate as exceptions.
  """
  # '<script dir name>/<script name>' for the generated header. Built with as_posix()
  # so the header reads the same on every OS (os.path.join would emit '\' on Windows).
  _path = Path(os.path.basename(os.path.dirname(__file__)), os.path.basename(__file__)).as_posix()
  for backend in _dependencies:
    # Pin encoding and line endings: the output is a checked-in generated file, and
    # platform defaults (e.g. CRLF / cp1252 on Windows) would make it churn between hosts.
    with get_target_dummy_file(backend).open('w', encoding='utf-8', newline='\n') as f:
      f.write('\n'.join(write_stub(backend, _path)))
  return 0

if __name__ == '__main__': raise SystemExit(main())
